--- title: Time Alignment with micro-tcn keywords: fastai sidebar: home_sidebar nb_path: "02_time_align.ipynb" ---
Work in progress for NASH Hackathon, Dec 17, 2021
This is like the 01_td_demo notebook, only we use a different dataset and generalize the dataloader a bit.
# Next line only executes on Colab. Colab users: Please enable GPU in Edit > Notebook settings
! [ -e /content ] && pip install -Uqq pip fastai git+https://github.com/drscotthawley/fastproaudio.git
# Additional installs for this tutorial
%pip install -q fastai_minima torchsummary pyzenodo3 wandb
# Install micro-tcn and auraloss packages (from source, will take a little while)
%pip install -q wheel --ignore-requires-python git+https://github.com/csteinmetz1/micro-tcn.git git+https://github.com/csteinmetz1/auraloss
# After this cell finishes, restart the kernel and continue below
from fastai.vision.all import *
from fastai.text.all import *
from fastai.callback.fp16 import *
import wandb
from fastai.callback.wandb import *
import torch
import torchaudio
import torchaudio.functional as F
import torchaudio.transforms as T
from IPython.display import Audio
import matplotlib.pyplot as plt
import torchsummary
from fastproaudio.core import *
from pathlib import Path
from glob import glob
import json
import re
# Dataset root -- placeholder string until the data location is settled (TODO: fix)
path = Path('wherever jacob puts the data')
fnames_in = sorted(glob(str(path)+'/*/input*'))    # all input files, one level of subdirs down
fnames_targ = sorted(glob(str(path)+'/*/*targ*'))  # matching target files, same layout
ind = -1 # pick one spot in the list of files
fnames_in[ind], fnames_targ[ind]  # notebook cell output: show the chosen (input, target) pair
Input audio
# Load the chosen input file; show_audio (fastproaudio helper) displays it in the notebook
waveform, sample_rate = torchaudio.load(fnames_in[ind])
show_audio(waveform, sample_rate)
Target output audio
# Load the corresponding target (processed) audio
target, sr_targ = torchaudio.load(fnames_targ[ind])
show_audio(target, sr_targ)
Let's look at the difference.
Difference
# Visualize the residual target-minus-input.  NOTE(review): assumes both tensors
# have identical shape and sample rate -- confirm for this dataset.
show_audio(target - waveform, sample_rate)
def get_accompanying_tracks(fn, fn_list, remove=False):
    """Given one filename and a list of all filenames, return a list of that filename and
    any files it 'goes with' -- i.e. files identical apart from a hyphen-designation,
    e.g. input_21-0_.wav, input_21-1_.wav and input_21-kick_.wav form one group.

    Args:
        fn (str): the filename to search on.
        fn_list (list[str]): list of all candidate filenames.
        remove (bool, optional): if True, remove the accompanying files from
            `fn_list` in place -- but never `fn` itself. (Default: False)

    Returns:
        list[str]: the accompanying filenames, in `fn_list` order.
            Note this list includes the original file `fn` too.
    """
    # Hoist the pattern out of the loop; the original re-ran re.sub's parse per element.
    hyphen_re = re.compile(r'-[a-zA-Z0-9]+')
    # Make copies of fn & fn_list with all hyphen+designation parts removed.
    basename = hyphen_re.sub('', fn)
    basename_list = [hyphen_re.sub('', x) for x in fn_list]
    # Indices of all elements of basename_list matching basename -> original filenames.
    accompanying = [fn_list[i] for i, x in enumerate(basename_list) if x == basename]
    if remove:
        for x in accompanying:
            if x != fn:
                fn_list.remove(x)  # don't remove the file we searched on though
    return accompanying
# Quick demo of get_accompanying_tracks() on a toy filename list:
fn_list = ['input_21-0_.wav', 'input_21-1_.wav', 'input_21-hey_.wav', 'input_22_.wav', 'input_23_.wav', 'input_23-toms_.wav', 'input_24-0_.wav', 'input_24-kick_.wav']
print(fn_list)
track = fn_list[1]
print("getting matching tracks for ",track)
tracks = get_accompanying_tracks(fn_list[1], fn_list, remove=True)
print("Accompanying tracks are: ",tracks)
print("new list = ",fn_list) # should have the extra 21- tracks removed.
# Now de-duplicate the whole list, group by group:
fn_list = ['input_21-0_.wav', 'input_21-1_.wav', 'input_21-hey_.wav', 'input_22_.wav', 'input_23_.wav', 'input_23-toms_.wav', 'input_24-0_.wav', 'input_24-kick_.wav']
fn_list_save = fn_list.copy()
# NOTE(review): this loop mutates fn_list (remove=True) while iterating over it, which
# makes Python skip the element that slides into each removed slot; it appears to work
# here because removed items sit immediately after the current one -- confirm intended.
for x in fn_list:
    get_accompanying_tracks(x, fn_list, remove=True)
fn_list, fn_list_save  # cell output: de-duplicated list vs. the saved original
The original dataset class that Christian made, for which we "pack" params and inputs together. This will be loading multichannel wav files
from microtcn.data import SignalTrainLA2ADataset
class SignalTrainLA2ADataset_fastai(SignalTrainLA2ADataset):
    """For fastai's sake, have __getitem__ pack the inputs and params together.

    fastai Learners expect (x, y) batches, so the (input, target, params) triple
    returned by the parent class is repacked as (cat(input, params), target);
    the model side unpacks them again before the real forward pass.
    """
    # NOTE: the original had a pass-through __init__(*args, **kwargs) that only
    # called super().__init__ -- removed as redundant; the parent's is inherited.

    def __getitem__(self, idx):
        input, target, params = super().__getitem__(idx)
        # Concatenate the params onto the end of the input's last (time) axis.
        return torch.cat((input, params), dim=-1), target
Dataset for loading multiple mono files and packing them together as multichannel:
# NOTE: the MonoToMCDataset class below is intentionally disabled -- it lives inside a
# module-level string literal and is kept only for reference.  Per the UPDATE note in
# its docstring, Christian's original dataloader plus conversion scripts is used instead.
'''
class MonoToMCDataset(torch.utils.data.Dataset):
"""
UPDATE: turns out we're going to stick to Christian's original dataloader class and just use
conversion scripts to pack or unpack mono WAV files into multichannel WAV files.
----
Modifying Steinmetz' micro-tcn code so we can load the kind of multichannel audio we want.
The difference is that now, we group files that are similar except for a hyphen-designation,
e.g. input_235-1_.wav and input_235-2_.wav get read into one tensor.
The 'trick' will be that we only ever store one filename 'version' of a group of files, but whenever we
want to try to load that file, we will also grab all its associated files.
Like SignalTrain LA2A dataset only more general"""
def __init__(self, root_dir, subset="train", length=16384, preload=False, half=True, fraction=1.0, use_soundfile=False):
"""
Args:
root_dir (str): Path to the root directory of the SignalTrain dataset.
subset (str, optional): Pull data either from "train", "val", "test", or "full" subsets. (Default: "train")
length (int, optional): Number of samples in the returned examples. (Default: 40)
preload (bool, optional): Read in all data into RAM during init. (Default: False)
half (bool, optional): Store the float32 audio as float16. (Default: True)
fraction (float, optional): Fraction of the data to load from the subset. (Default: 1.0)
use_soundfile (bool, optional): Use the soundfile library to load instead of torchaudio. (Default: False)
"""
self.root_dir = root_dir
self.subset = subset
self.length = length
self.preload = preload
self.half = half
self.fraction = fraction
self.use_soundfile = use_soundfile
if self.subset == "full":
self.target_files = glob.glob(os.path.join(self.root_dir, "**", "target_*.wav"))
self.input_files = glob.glob(os.path.join(self.root_dir, "**", "input_*.wav"))
else:
# get all the target files files in the directory first
self.target_files = glob.glob(os.path.join(self.root_dir, self.subset.capitalize(), "target_*.wav"))
self.input_files = glob.glob(os.path.join(self.root_dir, self.subset.capitalize(), "input_*.wav"))
self.examples = []
self.minutes = 0 # total number of minutes in the subset
# ensure that the sets are ordered correctly
self.target_files.sort()
self.input_files.sort()
# get the parameters
self.params = [(float(f.split("__")[1].replace(".wav","")), float(f.split("__")[2].replace(".wav",""))) for f in self.target_files]
# SHH: HERE is where we'll package similar hyphen-designated files together. list comprehension here wouldn't be good btw.
# essentially we are removing 'duplicates'. the first file of each group will be the signifier of all of them
self.target_files_all, self.input_files_all = self.target_files.copy(), self.input_files.copy() # save a copy of original list
for x in self.target_files: # remove extra accompanying tracks from main list that loader will use
get_accompanying_tracks(x, self.target_files, remove=True)
for x in self.input_files:
get_accompanying_tracks(x, self.input_files, remove=True)
# make a dict that will map main file name to list of accompanying files (including itself)
self.target_accomp = {f: get_accompanying_tracks(f, self.target_files_all) for f in self.target_files}
self.input_accomp = {f: get_accompanying_tracks(f, self.input_files_all) for f in self.input_files}
# loop over files to count total length
for idx, (tfile, ifile, params) in enumerate(zip(self.target_files, self.input_files, self.params)):
ifile_id = int(os.path.basename(ifile).split("_")[1])
tfile_id = int(os.path.basename(tfile).split("_")[1])
if ifile_id != tfile_id:
raise RuntimeError(f"Found non-matching file ids: {ifile_id} != {tfile_id}! Check dataset.")
md = torchaudio.info(tfile)
num_frames = md.num_frames
if self.preload:
sys.stdout.write(f"* Pre-loading... {idx+1:3d}/{len(self.target_files):3d} ...\r")
sys.stdout.flush()
input, sr = self.load_accompanying(ifile, self.input_accomp)
target, sr = self.load_accompanying(tfile, self.target_accomp)
num_frames = int(np.min([input.shape[-1], target.shape[-1]]))
if input.shape[-1] != target.shape[-1]:
print(os.path.basename(ifile), input.shape[-1], os.path.basename(tfile), target.shape[-1])
raise RuntimeError("Found potentially corrupt file!")
if self.half:
input = input.half()
target = target.half()
else:
input = None
target = None
# create one entry for each patch
self.file_examples = []
for n in range((num_frames // self.length)):
offset = int(n * self.length)
end = offset + self.length
self.file_examples.append({"idx": idx,
"target_file" : tfile,
"input_file" : ifile,
"input_audio" : input[:,offset:end] if input is not None else None,
"target_audio" : target[:,offset:end] if input is not None else None,
"params" : params,
"offset": offset,
"frames" : num_frames})
# add to overall file examples
self.examples += self.file_examples
# use only a fraction of the subset data if applicable
if self.subset == "train":
classes = set([ex['params'] for ex in self.examples])
n_classes = len(classes) # number of unique compressor configurations
fraction_examples = int(len(self.examples) * self.fraction)
n_examples_per_class = int(fraction_examples / n_classes)
n_min_total = ((self.length * n_examples_per_class * n_classes) / md.sample_rate) / 60
n_min_per_class = ((self.length * n_examples_per_class) / md.sample_rate) / 60
print(sorted(classes))
print(f"Total Examples: {len(self.examples)} Total classes: {n_classes}")
print(f"Fraction examples: {fraction_examples} Examples/class: {n_examples_per_class}")
print(f"Training with {n_min_per_class:0.2f} min per class Total of {n_min_total:0.2f} min")
if n_examples_per_class <= 0:
raise ValueError(f"Fraction `{self.fraction}` set too low. No examples selected.")
sampled_examples = []
for config_class in classes: # select N examples from each class
class_examples = [ex for ex in self.examples if ex["params"] == config_class]
example_indices = np.random.randint(0, high=len(class_examples), size=n_examples_per_class)
class_examples = [class_examples[idx] for idx in example_indices]
extra_factor = int(1/self.fraction)
sampled_examples += class_examples * extra_factor
self.examples = sampled_examples
self.minutes = ((self.length * len(self.examples)) / md.sample_rate) / 60
# we then want to get the input files
print(f"Located {len(self.examples)} examples totaling {self.minutes:0.2f} min in the {self.subset} subset.")
def __len__(self):
return len(self.examples)
def __getitem__(self, idx):
if self.preload:
audio_idx = self.examples[idx]["idx"]
offset = self.examples[idx]["offset"]
input = self.examples[idx]["input_audio"]
target = self.examples[idx]["target_audio"]
else:
offset = self.examples[idx]["offset"]
input_name = self.examples[idx]["input_file"]
target_name = self.examples[idx]["target_file"]
input = torch.empty((len(self.input_accomp[input_name]), self.length))
for c, fname in enumerate(self.input_accomp[input_name]):
input[c], sr = torchaudio.load(fname,
num_frames=self.length,
frame_offset=offset,
normalize=False)
target = torch.empty((len(self.target_accomp[target_name]), self.length))
for c, fname in enumerate(self.target_accomp[target_name]):
target[c], sr = torchaudio.load(fname,
num_frames=self.length,
frame_offset=offset,
normalize=False)
if self.half:
input = input.half()
target = target.half()
# at random with p=0.5 flip the phase
if np.random.rand() > 0.5:
input *= -1
target *= -1
# then get the tuple of parameters
params = torch.tensor(self.examples[idx]["params"]).unsqueeze(0)
params[:,1] /= 100
return input, target, params
def load(self, filename):
if self.use_soundfile:
x, sr = sf.read(filename, always_2d=True)
x = torch.tensor(x.T)
else:
x, sr = torchaudio.load(filename, normalize=False)
return x, sr
def load_accompanying(self, filename, accomp_dict):
accomp = accomp_dict[filename]
self.num_channels = len(accomp)
md = torchaudio.info(filename) # TODO:fix: assumes all accompanying tracks are the same shape, etc!
num_frames = md.num_frames
data = torch.empty((self.num_channels,num_frames))
for c, afile in enumerate(accomp):
data[c], sr = self.load(afile)
return data, sr
'''
class Args(object): # stand-in for parseargs. these are all micro-tcn defaults
    """Namespace of training settings, mirroring micro-tcn's argparse defaults."""
    # model & data location
    model_type = 'tcn'
    root_dir = str(path)
    preload = False
    sample_rate = 44100
    shuffle = True
    # subset names and example lengths (samples)
    train_subset = 'train'
    val_subset = 'val'
    train_length = 65536
    train_fraction = 1.0
    eval_length = 131072
    # training knobs
    batch_size = 8   # original is 32, my laptop needs smaller, esp. w/o half precision
    num_workers = 4
    precision = 32   # LEAVE AS 32 FOR NOW: HALF PRECISION (16) NOT WORKING YET -SHH
    n_params = 2


args = Args()
#if args.precision == 16: torch.set_default_dtype(torch.float16)
# setup the dataloaders
train_dataset = SignalTrainLA2ADataset_fastai(args.root_dir,
                                              subset=args.train_subset,
                                              fraction=args.train_fraction,
                                              half=True if args.precision == 16 else False,
                                              preload=args.preload,
                                              length=args.train_length)
train_dataloader = torch.utils.data.DataLoader(train_dataset,
                                               shuffle=args.shuffle,
                                               batch_size=args.batch_size,
                                               num_workers=args.num_workers,
                                               pin_memory=True)  # pinned memory speeds host->GPU copies
# Validation set: longer examples (eval_length), no fraction subsampling, no shuffle
val_dataset = SignalTrainLA2ADataset_fastai(args.root_dir,
                                            preload=args.preload,
                                            half=True if args.precision == 16 else False,
                                            subset=args.val_subset,
                                            length=args.eval_length)
val_dataloader = torch.utils.data.DataLoader(val_dataset,
                                             shuffle=False,
                                             batch_size=args.batch_size,
                                             num_workers=args.num_workers,
                                             pin_memory=True)
If the user requested fp16 precision then we need to install NVIDIA apex:
# NOTE: `if False and ...` makes this a permanently-disabled path; kept as a reminder
# of the apex install/import needed if fp16 support gets fixed later.
if False and args.precision == 16:
    %pip install -q --disable-pip-version-check --no-cache-dir git+https://github.com/NVIDIA/apex
    from apex.fp16_utils import convert_network
from microtcn.tcn_bare import TCNModel as TCNModel
#from microtcn.lstm import LSTMModel # actually the LSTM depends on a lot of Lightning stuff, so we'll skip that
from microtcn.utils import center_crop, causal_crop
class TCNModel_fastai(TCNModel):
    """For fastai's sake, unpack the inputs and params before the parent forward().

    The fastai-style dataset packs the conditioning params onto the end of the
    input's last axis; forward() splits them apart again when `p` isn't given.
    """
    # NOTE: the original had a pass-through __init__(*args, **kwargs) that only
    # called super().__init__ -- removed as redundant; the parent's is inherited.

    def forward(self, x, p=None):
        if (p is None) and (self.nparams > 0):  # unpack the params if needed
            assert len(list(x.size())) == 3     # sanity check: (batch, chans, time+nparams)
            x, p = x[:, :, 0:-self.nparams], x[:, :, -self.nparams:]
        return super().forward(x, p=p)
# micro-tcn defines several different model configurations. I just chose one of them.
train_configs = [
    {"name" : "TCN-300",
     "model_type" : "tcn",
     "nblocks" : 10,              # number of TCN blocks
     "dilation_growth" : 2,       # dilation growth factor per block
     "kernel_size" : 15,
     "causal" : False,            # non-causal -> use centered crops (see Crop_Loss below)
     "train_fraction" : 1.00,
     "batch_size" : args.batch_size
    }
]
dict_args = train_configs[0]
dict_args["nparams"] = 2  # two conditioning params, matching args.n_params
model = TCNModel_fastai(**dict_args)
dtype = torch.float32  # global dtype used for .to() calls throughout the notebook
Let's take a look at the model:
# this summary allows one to compare the original TCNModel with the TCNModel_fastai
if type(model) == TCNModel_fastai:
    # packed variant: params ride along at the end of the time axis -> one input tensor
    torchsummary.summary(model, [(1,args.train_length)], device="cpu")
else:
    # original model takes (audio, params) as two separate inputs
    torchsummary.summary(model, [(1,args.train_length),(1,2)], device="cpu")
Zach Mueller made a very helpful fastai_minima package that we'll use, and follow his instructions.
TODO: Zach says I should use either
fastai or fastai_minima, not mix them like I'm about to do. But what I have below is the only thing that works right now. ;-)
# I guess we could've imported these up at the top of the notebook...
from torch import optim
from fastai_minima.optimizer import OptimWrapper
#from fastai_minima.learner import Learner # this doesn't include lr_find()
from fastai.learner import Learner
from fastai_minima.learner import DataLoaders
#from fastai_minima.callback.training_utils import CudaCallback, ProgressCallback # note sure if I need these
# Plain SGD wrapped for fastai's optimizer interface (per fastai_minima's instructions)
def opt_func(params, **kwargs): return OptimWrapper(optim.SGD(params, **kwargs))

dls = DataLoaders(train_dataloader, val_dataloader)

if args.precision==16:
    # fp16 path (currently disabled -- see the apex cell above, which never runs)
    dtype = torch.float16
    model = convert_network(model, torch.float16)

model = model.to('cuda:0')  # NOTE(review): assumes a CUDA GPU is present

# Sanity check: push one batch through the model, outside of any Learner
if type(model) == TCNModel_fastai:
    print("We're using Hawley's modified code")
    packed, targ = dls.one_batch()
    # unpacked here only for the shape report below; the model unpacks internally
    inp, params = packed[:,:,0:-dict_args['nparams']], packed[:,:,-dict_args['nparams']:]
    pred = model.forward(packed.to('cuda:0', dtype=dtype))
else:
    print("We're using Christian's version of Dataloader and model")
    inp, targ, params = dls.one_batch()
    pred = model.forward(inp.to('cuda:0',dtype=dtype), p=params.to('cuda:0', dtype=dtype))
print(f"input = {inp.size()}\ntarget = {targ.size()}\nparams = {params.size()}\npred = {pred.size()}")
We can make the pred and target the same length by cropping when we compute the loss:
class Crop_Loss:
    """Crop target size to match preds, then apply an elementwise loss (default L1).

    Args:
        axis (int): crop axis; stored by store_attr but not otherwise read here.
        causal (bool): if True use causal_crop, else center_crop, to trim the target.
        reduction (str): "mean" or "sum" over the flattened elementwise loss.
        func: loss class to instantiate (default nn.L1Loss).
    """
    def __init__(self, axis=-1, causal=False, reduction="mean", func=nn.L1Loss):
        store_attr()
        self.loss_func = func()

    def __call__(self, pred, targ):
        # Conv layers shrink pred relative to targ; trim targ down to match.
        targ = causal_crop(targ, pred.shape[-1]) if self.causal else center_crop(targ, pred.shape[-1])
        assert pred.shape == targ.shape, f'pred.shape = {pred.shape} but targ.shape = {targ.shape}'
        loss = self.loss_func(pred, targ).flatten()
        # BUG FIX: the "sum" branch previously called `loss(pred,targ)` -- an undefined
        # name that raised NameError whenever reduction != "mean".
        return loss.mean() if self.reduction == "mean" else loss.sum()
# we could add a metric like MSE if we want
def crop_mse(pred, targ, causal=False):
    "MSE metric computed after cropping `targ` down to `pred`'s length."
    crop = causal_crop if causal else center_crop
    diff = pred - crop(targ, pred.shape[-1])
    return (diff ** 2).mean()
wandb.login()  # authenticate with Weights & Biases (prompts for an API key if needed)
class WandBAudio(Callback):
    """Progress-like callback: log audio predictions to WandB after each validation epoch.

    Args:
        n_preds (int): max number of batch elements to log per epoch.
        sample_rate (int): sample rate passed to wandb.Audio.
    """
    order = ProgressCallback.order+1

    def __init__(self, n_preds=5, sample_rate=44100):
        store_attr()

    def after_epoch(self):
        if not self.learn.training:  # only log on the validation pass
            with torch.no_grad():
                # FIX: the original also copied self.learn.y to CPU but never used it --
                # that dead GPU->CPU transfer is removed.
                preds = self.learn.pred.detach().cpu().numpy().copy()
            log_dict = {}
            for i in range(min(self.n_preds, preds.shape[0])):  # note wandb only supports mono
                log_dict[f"preds_{i}"] = wandb.Audio(preds[i,0,:], caption=f"preds_{i}",
                                                     sample_rate=self.sample_rate)
            wandb.log(log_dict)
# Start a WandB run; WandbCallback will log losses/metrics during training
wandb.init(project='micro-tcn-fastai')# no name, name=json.dumps(dict_args))
learn = Learner(dls, model, loss_func=Crop_Loss(), metrics=crop_mse, opt_func=opt_func,
                cbs= [WandbCallback()])
We can use the fastai learning rate finder to suggest a learning rate:
# LR range test: sweep learning rates up to end_lr and plot loss to pick lr_max
learn.lr_find(end_lr=0.1)
And now we'll train using the one-cycle LR schedule, with the WandBAudio callback. (Ignore any warning messages)
epochs = 20 # change to 50 for better results but a longer wait
# One-cycle LR schedule; WandBAudio logs audio predictions after each validation epoch
learn.fit_one_cycle(epochs, lr_max=3e-3, cbs=WandBAudio(sample_rate=args.sample_rate))
wandb.finish() # call wandb.finish() after training or your logs may be incomplete
learn.save('micro-tcn-fastai')  # presumably saved under learn.path/models/ -- confirm
Go check out the resulting run logs, graphs, and audio samples at https://wandb.ai/drscotthawley/micro-tcn-fastai, or... lemme see if I can embed some results below:
...ok it looks like the WandB results iframe (with cool graphs & audio) is getting filtered out of the docs (by nbdev and/or jekyll), but if you open this notebook file -- e.g. click the "Open in Colab" badge at the top -- then scroll down and you'll see the report. Or just go to the WandB link posted above!
# Build the held-out test set the same way as val, but from the 'test' subset
test_dataset = SignalTrainLA2ADataset_fastai(args.root_dir,
                                             preload=args.preload,
                                             half=True if args.precision == 16 else False,
                                             subset='test',
                                             length=args.eval_length)
test_dataloader = torch.utils.data.DataLoader(test_dataset,
                                              shuffle=False,  # deterministic order for eval
                                              batch_size=args.batch_size,
                                              num_workers=args.num_workers,
                                              pin_memory=True)
# Fresh Learner (no wandb callback) just to reload the trained weights
learn = Learner(dls, model, loss_func=Crop_Loss(), metrics=crop_mse, opt_func=opt_func, cbs=[])
learn.load('micro-tcn-fastai')
^^ 9 examples? I thought there were only 3:
!ls {path}/Test
...Ok I don't understand that yet. Moving on:
Let's get some predictions from the model. Note that these predictions will be longer than those in training, because we specified the lengths differently:
print(args.train_length, args.eval_length)
Handy routine to grab some data and run it through the model to get predictions:
def get_pred_batch(dataloader, crop_target=True, causal=False):
    "Grab one batch, run it through the (global) model, and return input, params, target, pred on CPU."
    packed, target = next(iter(dataloader))
    nparams = dict_args['nparams']
    # split the packed tensor back into audio and conditioning params
    input, params = packed[:, :, :-nparams], packed[:, :, -nparams:]
    pred = model.forward(packed.to('cuda:0', dtype=dtype))  # NOTE(review): assumes CUDA is available
    if crop_target:
        crop = causal_crop if causal else center_crop
        target = crop(target, pred.shape[-1])
    input, params, target, pred = (x.detach().cpu() for x in (input, params, target, pred))
    return input, params, target, pred
# Listen to one model prediction vs. its target, from the first test batch
input, params, target, pred = get_pred_batch(test_dataloader, causal=dict_args['causal'])
i = 0 # just look at the first element
print(f"------- i = {i} ---------\n")
print(f"prediction:")
show_audio(pred[i], sample_rate)
print(f"target:")
show_audio(target[i], sample_rate)
TODO: More. We're not finished. I'll come back and add more to this later.
Check out Christian's GitHub page for micro-tcn where he provides instructions and JUCE files by which to render the model as an audio plugin. Pretty sure you can only do this with the causal models, which I didn't include -- yet!